Author: Dong Liang
E-mail: ldifer@gmail.com
The novel coronavirus disease (COVID-19), which started in late 2019, has developed into a global pandemic, posing an immediate and ongoing threat to the health and economic activities of billions of people today. The severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2), which causes COVID-19, is characterized by rapid and efficient individual-to-individual transmission with a range of clinical courses including severe acute respiratory distress syndrome, viral pneumonia, mild upper respiratory infections (URIs) and asymptomatic carriers [1]. Covariates associated with worse outcomes include hypertension, diabetes, coronary heart disease and older age [1]. A study of COVID-19 cases on the Diamond Princess cruise ship in Japan estimated the proportion of asymptomatic patients to be 17.9% (95% CrI: 15.5-20.2%) [2]. All of these present great challenges for the prevention and control of COVID-19 transmission.
There is clear evidence that SARS-CoV-2 is evolving rapidly. A recent phylogenetic network analysis of 160 SARS-CoV-2 genomes identified three central variants based on amino acid changes [3]. In addition, Tang et al. found that two SNPs in strong linkage disequilibrium at locations 8,782 (orf1ab: T8517C, synonymous) and 28,144 (ORF8: C251T, S84L) can form haplotypes that classify SARS-CoV-2 viruses into two major lineages (L and S types) [4]. Mutations also frequently occur in the receptor-binding domain (RBD) of the spike protein, which mediates infection of human cells [5]. A recent analysis of the viral genomes of 6,000 infected people identified one mutation (named D614G) in the spike protein that is associated with increased virus transmissibility [6]. Clearly, the dynamic evolution of the viral genome has important effects on the spread, pathogenesis and immune intervention of SARS-CoV-2.
Machine learning methods have been successfully applied to classify different types of cancer and to identify potentially valuable disease biomarkers [7-14]. In addition, convolutional neural networks (CNNs) have become the method of choice for medical image recognition and classification. Their special convolution and pooling architectures and parameter-sharing mechanism make them computationally more efficient than traditional fully connected neural networks. Despite their great popularity in various computer vision tasks, CNNs are less commonly employed in the field of genome sequence analysis. This study attempted to use a state-of-the-art CNN-based autoencoder to perform representation learning on 3,161 full-length RNA genome sequences of SARS-CoV-2 collected from various U.S. states and from around the world. The model prototype developed in this study could serve as a first step toward developing a disease risk scoring system in the future.
# from `covid19_util.seq_util import *
# from covid19_util.web_util import *
from covid19_toolkit import *
from covid19_toolkit.seq_util import *
from covid19_toolkit.model_util import *
from covid19_toolkit.model_util2 import *
# Seq processing
import pysam
import vcf
from Bio import SeqIO
from Bio.Seq import Seq
# Data processing
import numpy as np
import pandas as pd
# Plot
from plotnine import *
import pickle
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import style
# Use ggplot-like styling for matplotlib figures.
style.use('ggplot')
# IPython notebook magics: auto-reload edited project modules on each cell run.
%load_ext autoreload
%autoreload 2
# Load the SARS-CoV-2 genome metadata table (project helper; presumably one
# row per GenBank nucleotide record -- TODO confirm schema).
virus_info = load_virus_info('nucleotide')
# Quick sanity checks: how many countries/regions and cities are represented.
len(virus_info.Country_Region.unique())
virus_info['City'].unique()
# Distribution of genome lengths across all records.
virus_info.Length.hist()
plt.title('Length of genomic sequences (SARS-Cov-2)')
# Top 10 countries/regions by number of sequenced genomes.
top10 = virus_info.groupby('Country_Region').agg(len)['Length'].sort_values(ascending=False).head(10).index.tolist()
# Restrict the metadata to records from those top-10 countries/regions.
virus_info_R = virus_info.loc[virus_info['Country_Region'].isin(top10), :]
# Boxplot + jittered points of genome length per top-10 country/region.
(
ggplot(aes(x='Country_Region', y='Length'), data=virus_info_R) +
geom_boxplot(alpha = 0.5) +
geom_jitter(alpha=0.2) +
theme(
# figure_size=(10,6),
legend_key=element_blank(),
axis_text_x = element_text(rotation=45, hjust=1.),
)
)
# Per-country histograms of genome length; free y-scales keep countries with
# few records visible.
g = (
ggplot(aes(x='Length'), data=virus_info_R) +
geom_histogram(aes(fill = 'Country_Region'), alpha = 0.9, bins = 40)+
facet_wrap('~Country_Region', scales = "free_y") +
theme(
figure_size=(10,6),
legend_key=element_blank(),
axis_text_x = element_text(rotation=45, hjust=1.),
)
)
g + ggtitle('Length distribution of SARS-Cov-2 in top10 Country/Region')
# Genome length over collection date, with a loess trend, faceted by
# country/region.
(
ggplot(aes(x='Collection_Date', y='Length'), data=virus_info_R) +
geom_point(aes(color = 'Country_Region')) +
stat_smooth(method='loess') +
scale_x_date(date_breaks = "1 month", date_labels = "%b %Y") +
ylab('SARS-CoV-2 length (bp)') +
facet_wrap('~Country_Region') +
theme(
figure_size=(6, 4),
legend_key=element_blank(),
axis_text_x = element_text(rotation=45, hjust=1.),
)
)
# Per-country mean and standard deviation of genome length (top-10 subset).
# Fixes the original line, which referenced the undefined name `virus_info_R1`
# and the bare identifiers `mean`/`std`; pandas accepts aggregation names as
# strings, so no star-imported numpy functions are needed.
virus_info_R.groupby('Country_Region').agg(['mean', 'std']).loc[:, 'Length']
# Sequence counts per country/region (top 10), reshaped into a tidy frame.
count_by_country = pd.DataFrame(virus_info.groupby('Country_Region').agg(len)['Length'].sort_values(ascending=False).head(10))
count_by_country.reset_index(inplace=True)
count_by_country.columns = ['Country_Region', 'Count']
# Reversed order so the largest bar appears on top after coord_flip().
manufacturer_list = count_by_country.Country_Region[::-1]
# Horizontal bar chart of counts per country/region.
(
ggplot(aes(x='Country_Region', y='Count'), data=count_by_country) +
geom_bar(stat = 'identity', size=10) +
scale_x_discrete(limits=manufacturer_list) +
coord_flip()+
theme(
# figure_size=(10,6),
legend_key=element_blank(),
axis_text_x = element_text(rotation=45, hjust=1.),
)
)
# Matplotlib version: counts by country/region (left) and by city/US state
# (right), side by side.
plt.subplot(121)
virus_info.groupby('Country_Region').agg(len)['Length'].sort_values(ascending=False).head(10).plot.bar()
plt.subplot(122)
virus_info.groupby('City').agg(len)['Length'].sort_values(ascending=False).head(10).plot.bar()
plt.xlabel('City/US state')
plt.gcf().set_size_inches(10, 4)
# Load the FASTA genome sequences and wrap each record in a project COVID19
# object (callable; presumably returns an encoded array per ORF -- TODO confirm).
coding_seqs = load_virus_seq('./nucleotide.fasta')
covid19_seqs = [COVID19(id, seq) for id, seq in coding_seqs.items()]
# Complete-genome encoding for every record, keyed by accession.
covid19_complete = {seq.id: seq(ORF = 'complete', return_residual = False) for seq in covid19_seqs }
# Pad all encodings to a common length so they can be stacked/batched.
covid19_dataset = zero_padding(covid19_complete)
# Spike-protein ORF only, keeping the residual sequence; inspect one accession.
spike = {seq.id: seq(ORF = 'spike', return_residual = True) for seq in covid19_seqs}
spike['MT509460']
# Reference seq: first 500 bases of the Wuhan-Hu-1 reference genome NC_045512.
coding_seqs['NC_045512']['Severe acute respiratory syndrome coronavirus 2 isolate Wuhan-Hu-1'][:500]
seq_NC = covid19_complete['NC_045512']
# Shape of the encoded reference genome (TODO confirm expected dimensions).
seq_NC.shape
# Data loading
complete_seqs = load_virus_seq('./nucleotide.fasta')
# Project preprocessing: encodes/pads the raw sequences into arrays
# (TODO confirm output dtype/shape).
covid19_dataset = preprocessing(complete_seqs)
# Minibatching - custom training: input and target are the same tensor because
# the autoencoder reconstructs its input.
ids = [id for id, seq in covid19_dataset.items()]
seq = [tf.cast(seq, tf.float32) for id, seq in covid19_dataset.items()]
train_dataset = tf.data.Dataset.from_tensor_slices((seq, seq)).cache().shuffle(7000).batch(32)
# Training data - keras training
ids = [id for id, seq in covid19_dataset.items()]
# Accessions whose collection month is March (month == 3).
ids_march = np.array(ids)[pd.Series(ids).isin(virus_info[virus_info['Collection_Date'].dt.month == 3]['Accession'])]
# March-only training array, and the full array used later for encoding.
seq = np.array([seq.astype(np.float32) for id, seq in covid19_dataset.items() if id in ids_march])
# seq.shape
seq_all = np.array([seq.astype(np.float32) for id, seq in covid19_dataset.items()])
The autoencoder was built on a 2D convolutional neural network, with architectures drawn from a range of combinations of maxpooling, dropout and convolutional layers. It turns out that the best performance was achieved using the simple settings shown below. However, the 2D convolutional neural network architecture did yield superior classification performance compared to the 1D CNN.
def train(model, epoch, loss_fn, train_dataset, \
          optimizer_fn, learning_rate, print_every = 10, \
          manager = None, **kwargs):
    """Custom minibatch training loop for the CNN autoencoder.

    Parameters
    ----------
    model : tf.keras.Model
        Autoencoder to train; its weights are updated in place.
    epoch : int
        Number of passes over ``train_dataset``.
    loss_fn : callable
        ``loss_fn(y, yhat)`` returning a (possibly per-element) loss tensor.
    train_dataset : iterable
        Yields ``(x, y)`` batches; for an autoencoder ``y == x``.
    optimizer_fn : callable
        Optimizer constructor, e.g. ``tf.optimizers.Adam``.
    learning_rate : float
        Learning rate passed to ``optimizer_fn``.
    print_every : int, optional
        Report metrics every ``print_every`` epochs and on the final epoch.
    manager : optional
        Checkpoint-manager placeholder; currently unused.
    **kwargs
        Optional ``valid_dataset`` iterable of ``(x, y)`` batches; when
        supplied, validation loss is reported alongside training loss.

    Returns
    -------
    The trained model (same object as ``model``).
    """
    optimizer = optimizer_fn(learning_rate = learning_rate)
    total_loss_train = tf.keras.metrics.Mean(name = 'total_loss_train')
    total_loss_val = tf.keras.metrics.Mean(name = 'total_loss_val')
    # Initialize metrics
    total_loss_train.reset_states()
    for k in range(epoch):
        for x, y in train_dataset:
            # `tape`, not the original `type`, which shadowed the builtin.
            with tf.GradientTape() as tape:
                # Model prediction and batch loss.
                yhat = model(x)
                total_loss = loss_fn(y, yhat)
                loss_mean = tf.reduce_mean(total_loss)  # average batch loss
            # Record metrics
            total_loss_train.update_state(total_loss)
            # Calculate gradients of weights and biases, then apply them.
            grad = tape.gradient(loss_mean, model.trainable_weights)
            optimizer.apply_gradients(zip(grad, model.trainable_weights))
        if k % print_every == 0 or k == epoch - 1:
            print('Epoch', k + 1)
            valid_dataset = kwargs.get('valid_dataset')
            if valid_dataset is None:
                print(f"Total_loss_train:{total_loss_train.result():.4f}")
            else:
                # The original validation branch referenced several undefined
                # names (total_loss_val was never created; loss_fun,
                # test_dataset, loss_val, accuracy_val, loss_train); it now
                # computes validation loss from the supplied dataset.
                total_loss_val.reset_states()
                for x_val, y_val in valid_dataset:
                    yhat_val = model(x_val)
                    total_loss_val.update_state(tf.reduce_mean(loss_fn(y_val, yhat_val)))
                print(
                    f"Train loss:{total_loss_train.result():.4f}",
                    f"Val Loss:{total_loss_val.result():.4f}",
                )
    return model # , loss_train, loss_val
# Define parameters for train_model function
# loss_fn = tf.nn.sigmoid_cross_entropy_with_logits
# CNN autoencoder parameters: two conv layers in the encoder, three transposed
# conv layers in the decoder, 4-D latent space.
params = {
'conv_filters': [32, 64],
'conv_kernel_size': [5, 5],
'conv_stride': [1, 1],
'convTranspose_filters': [64, 32, 1],
'convTranspose_kernel_size': [5, 5, 5],
'convTranspose_stride': [1, 1, 1],
'latent_dim': 4
}
# Hyperparameters
model = Autoencoder_cnn3(**params)
loss_fn = tf.keras.losses.BinaryCrossentropy()
epoch = 1
optimizer_fn = tf.optimizers.Adam # tf.train.AdamOptimizer
learning_rate = 3e-4 # 3e-4 # 5e-3
# Train dataset
# NOTE(review): the next line is a no-op self-assignment, apparently kept as a
# placeholder for swapping in alternative minibatch datasets.
train_dataset = train_dataset # , X_minibatch10, S_mag_minibatch10p
# Run the custom training loop defined above (no validation data supplied).
trained_model = train(model, epoch, loss_fn, train_dataset, \
optimizer_fn, learning_rate = learning_rate, \
print_every = 10, manager = None,
# valid_dataset = valid_dataset,
# X_validation=X_valid, y_validation=y_valid
)
# Keras-fit variant of the CNN autoencoder, this time with a 3-D latent space
# (so the embedding can be plotted directly).
params = {
    'conv_filters': [32, 64],
    'conv_kernel_size': [5, 5],
    'conv_stride': [1, 1],
    'convTranspose_filters': [64, 32, 1],
    'convTranspose_kernel_size': [5, 5, 5],
    'convTranspose_stride': [1, 1, 1],
    'latent_dim': 3
}
# Hyperparameters
model_keras = Autoencoder_cnn3(**params)
# model_keras(tf.reshape(seq[0], (1, 30018, 5)))
# model.summary()
# Fully qualify the optimizer and use the modern `learning_rate` keyword; the
# original called a bare `Adam(lr=0.01)`, which relied on a star import and
# the deprecated `lr` alias.
model_keras.compile(optimizer = tf.keras.optimizers.Adam(learning_rate=0.01), loss = tf.keras.losses.BinaryCrossentropy(), metrics=['accuracy'])
# Fit the autoencoder to reconstruct its input (input == target).
model_keras.fit(seq, seq, epochs = 3, batch_size = 32) # , batch_size = 128
# Encode all genomes into the 3-D latent space.  Set RERUN = True to recompute;
# otherwise load the cached predictions from disk.
RERUN = False
if RERUN:
    # Encode in fixed-size chunks to bound memory use.  The original
    # hard-coded five slices of 500 (plus a dead `np.empty((2868, 3))`
    # pre-allocation that was immediately overwritten); a loop generalizes to
    # any dataset size.
    chunk = 500
    pred = np.concatenate(
        [model_keras.encoder(seq_all[i:i + chunk])
         for i in range(0, len(seq_all), chunk)],
        axis = 0,
    )
    # Context managers guarantee the pickle file handles are closed.
    with open("prediction_all.pkl", "wb") as fh:
        pickle.dump(pred, fh)
else:
    with open("prediction_all.pkl", "rb") as fh:
        pred = pickle.load(fh)
# Join the latent coordinates onto the metadata table.
pred_df = pd.DataFrame(np.c_[virus_info.values, pred], columns = list(virus_info.columns) + ['d1', 'd2', 'd3'])
# np.c_ on mixed columns yields an object array, so cast the latent
# coordinates back to float for plotting.
pred_df[['d1', 'd2', 'd3']] = pred_df[['d1', 'd2', 'd3']].astype('float')
# By month: USA subset.  `.copy()` avoids pandas' SettingWithCopyWarning when
# the Collection_month column is added below (the original mutated a view of
# pred_df).
pred_USA = pred_df.query("Country_Region == 'USA'").copy()
pred_USA['Collection_month'] = pred_USA.Collection_Date.dt.month
# Prefix month numbers so categorical facets sort chronologically; plain
# assignment replaces the deprecated inplace replace-on-column pattern.
pred_USA['Collection_month'] = pred_USA['Collection_month'].replace(
    {1: "1, Jan", 2: "2, Feb", 3: '3, Mar', 4: '4, Apr', 5: '5, May'})
# By state: keep only cities/US states with more than 100 sequenced genomes.
gt = pred_USA.groupby('City').count()['Accession'] > 100
city_list = np.array(pred_USA.groupby('City').count()['Accession'].index[gt])
pred_USA_select = pred_USA.loc[pred_USA.City.isin(city_list), :]
# By location: top-10 countries/regions only.
pred_df_top10 = pred_df.loc[pred_df.Country_Region.isin(top10), :]
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline
# %matplotlib widget
# 3-D scatter of the learned latent coordinates for all genomes.
d = pred_df.loc[:, ['d1', 'd2', 'd3']].values.astype('float')
fig = plt.figure(figsize=(8,7))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(d[:, 0], d[:, 1], d[:, 2], c='r', marker='o')
# Wuhan-Hu-1 reference genome, used later as an anchor point in the plots.
ref = pred_df.loc[pred_df.Accession == 'NC_045512', :]
# Four selected countries for faceted comparison.
top4 = pred_df.loc[pred_df.Country_Region.isin(['CHINA', 'USA', 'AUSTRALIA', 'GERMANY']), :]
# Latent dimensions 1 vs 2, faceted by country; the black star marks the
# Wuhan-Hu-1 reference genome.
(
ggplot(aes(x='d1', y='d2', color = 'Country_Region'), data=top4) + # pred_df_top10
geom_point() +
geom_point(aes(x='d1', y='d2'), color = 'black', shape = '*', size= 10, alpha = 0.3, data = ref) +
# scale_y_discrete(minor_breaks=[]) +
# scale_x_discrete(limits=manufacturer_list) +
# coord_flip()+
xlab('Dimension 1') +
ylab('Dimension 2') +
facet_wrap('~Country_Region') +
theme(
figure_size=(10,6),
legend_key=element_blank(),
legend_position = "top",
legend_title = element_blank(),
axis_text_x = element_text(rotation=45, hjust=1.)
)
)
%matplotlib inline
# By state: keep US cities/states with more than 43 sequenced genomes, then
# plot latent dimensions 1 vs 2 faceted by city/state.
gt = pred_USA.groupby('City').count()['Accession'] > 43
city_list = np.array(pred_USA.groupby('City').count()['Accession'].index[gt])
pred_USA_select = pred_USA.loc[pred_USA.City.isin(city_list), :]
(
ggplot(aes(x='d1', y='d2', color = 'City'), data=pred_USA_select) +
geom_point() +
xlab('Dimension 1') +
ylab('Dimension 2') +
facet_wrap('~City') +
theme(
figure_size=(10,6),
legend_key=element_blank(),
legend_position = "top",
legend_title = element_blank(),
axis_text_x = element_text(rotation=45, hjust=1.)
)
)
%matplotlib inline
# Restrict to cities/states with more than 100 genomes, then facet the latent
# embedding by state (rows) x collection month (columns).
gt = pred_USA.groupby('City').count()['Accession'] > 100
city_list = np.array(pred_USA.groupby('City').count()['Accession'].index[gt])
pred_USA_select = pred_USA.loc[pred_USA.City.isin(city_list), :]
(
# NOTE(review): 'z' is not a 2-D aesthetic; plotnine presumably ignores it
# here -- confirm it does not raise on newer versions.
ggplot(aes(x='d1', y='d2', z = 'd3', color = 'Collection_month'), data=pred_USA_select) +
geom_point() +
facet_grid('City ~ Collection_month') +
theme(
figure_size=(12,8),
legend_key=element_blank(),
legend_position = "top",
legend_title = element_blank(),
axis_text_x = element_text(rotation=45, hjust=1.)
)
)